/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.solr.client.solrj.io.stream;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import org.apache.solr.client.solrj.SolrRequest;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.ComparatorOrder;
import org.apache.solr.client.solrj.io.comp.FieldComparator;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.io.stream.metrics.Bucket;
import org.apache.solr.client.solrj.io.stream.metrics.CountMetric;
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
import org.apache.solr.client.solrj.request.QueryRequest;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;

/**
 * Stream source that runs a two-level terms facet (an {@code x} bucket with a nested {@code y}
 * bucket) against a Solr collection using the {@code json.facet} request parameter, and emits one
 * tuple per (x value, y value) pair together with a single metric value.
 *
 * <p>The request is built and executed in {@link #open()}; all resulting tuples are buffered in
 * memory and handed out by {@link #read()}, followed by an EOF tuple carrying a {@code rows} count.
 */
public class Facet2DStream extends TupleStream implements Expressible {

  private static final long serialVersionUID = 1;

  /** Bucket limit used for each dimension when no explicit {@code dimensions} value is supplied. */
  private static final int DEFAULT_DIMENSION = 10;

  private String collection;
  private ModifiableSolrParams params;
  private Bucket x;
  private Bucket y;
  private Metric metric;
  private String zkHost;
  private Iterator<Tuple> out;
  private List<Tuple> tuples = new ArrayList<>();
  private int dimensionX;
  private int dimensionY;
  private FieldComparator bucketSort;

  protected transient SolrClientCache clientCache;
  private transient boolean doCloseCache;

  /**
   * Creates the stream from already-parsed components.
   *
   * @param zkHost ZooKeeper host string used to obtain the cloud client
   * @param collection collection the facet request is sent to
   * @param params base query parameters; copied, never modified
   * @param x outer (first level) bucket
   * @param y inner (second level) bucket
   * @param dimensions comma separated pair {@code "limitX,limitY"}; when {@code null} both limits
   *     default to {@value #DEFAULT_DIMENSION}, matching the expression based constructor
   * @param metric metric computed per (x, y) bucket; also used as the descending sort key
   * @throws IOException if {@code dimensions} does not contain exactly two comma separated values
   * @throws NumberFormatException if either dimension value is not an integer
   */
  public Facet2DStream(
      String zkHost,
      String collection,
      ModifiableSolrParams params,
      Bucket x,
      Bucket y,
      String dimensions,
      Metric metric)
      throws IOException {
    int limitX = DEFAULT_DIMENSION;
    int limitY = DEFAULT_DIMENSION;
    if (dimensions != null) {
      String[] strDimensions = dimensions.split(",");
      if (strDimensions.length != 2) {
        throw new IOException(
            String.format(
                Locale.ROOT, "two dimension values expected. %s found", strDimensions.length));
      }
      // Trim so that "10, 10" is accepted, as it is by the expression based constructor.
      limitX = Integer.parseInt(strDimensions[0].trim());
      limitY = Integer.parseInt(strDimensions[1].trim());
    }
    FieldComparator sort = parseBucketSort(metric.getIdentifier() + " desc");

    init(collection, params, x, y, sort, limitX, limitY, metric, zkHost);
  }

  /**
   * Creates the stream from a streaming expression.
   *
   * <p>The first operand is the collection name. The named operands {@code x} and {@code y} are
   * required; {@code dimensions} and {@code zkHost} are optional. Every other named operand is
   * forwarded as a query parameter, and {@code q} defaults to {@code *:*}. When no metric operand
   * is present a {@link CountMetric} is used.
   *
   * @param expression the expression to parse
   * @param factory factory used to read operands, build the metric and resolve the zkHost
   * @throws IOException if the collection name, the x/y buckets or the zkHost cannot be
   *     determined, or if {@code dimensions} does not contain exactly two values
   * @throws NumberFormatException if either dimension value is not an integer
   */
  public Facet2DStream(StreamExpression expression, StreamFactory factory) throws IOException {

    String collectionName = factory.getValueOperand(expression, 0);

    // Validate before dereferencing; otherwise a missing operand fails with an NPE below instead
    // of this message.
    if (collectionName == null) {
      throw new IOException(
          String.format(
              Locale.ROOT,
              "invalid expression %s - collectionName expected as first operand",
              expression));
    }

    if (collectionName.indexOf('"') > -1) {
      collectionName = collectionName.replace("\"", "").replace(" ", "");
    }

    List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
    StreamExpressionNamedParameter bucketXExpression = factory.getNamedOperand(expression, "x");
    StreamExpressionNamedParameter bucketYExpression = factory.getNamedOperand(expression, "y");
    StreamExpressionNamedParameter dimensionsExpression =
        factory.getNamedOperand(expression, "dimensions");
    List<StreamExpression> metricExpression =
        factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, Metric.class);
    StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");

    // Everything that is not one of this stream's own operands is passed through to Solr.
    ModifiableSolrParams params = new ModifiableSolrParams();
    for (StreamExpressionNamedParameter namedParam : namedParams) {
      if (!namedParam.getName().equals("x")
          && !namedParam.getName().equals("y")
          && !namedParam.getName().equals("dimensions")
          && !namedParam.getName().equals("zkHost")) {
        params.add(namedParam.getName(), namedParam.getParameter().toString().trim());
      }
    }

    if (params.get("q") == null) {
      params.set("q", "*:*");
    }

    Bucket x = bucketFrom(bucketXExpression);
    Bucket y = bucketFrom(bucketYExpression);

    if (x == null || y == null) {
      throw new IOException(
          String.format(
              Locale.ROOT,
              "invalid expression %s - x and y buckets expected. eg. 'x=\"name\"'",
              expression));
    }

    Metric metric;
    if (metricExpression == null || metricExpression.size() == 0) {
      metric = new CountMetric();
    } else {
      metric = factory.constructMetric(metricExpression.get(0));
    }

    FieldComparator bucketSort = parseBucketSort(metric.getIdentifier() + " desc");

    int dimensionX = DEFAULT_DIMENSION;
    int dimensionY = DEFAULT_DIMENSION;
    if (dimensionsExpression != null
        && dimensionsExpression.getParameter() instanceof StreamExpressionValue) {
      String[] strDimensions =
          ((StreamExpressionValue) dimensionsExpression.getParameter()).getValue().split(",");
      if (strDimensions.length != 2) {
        throw new IOException(
            String.format(
                Locale.ROOT, "two dimension values expected. Found %s", strDimensions.length));
      }
      dimensionX = Integer.parseInt(strDimensions[0].trim());
      dimensionY = Integer.parseInt(strDimensions[1].trim());
    }

    String zkHost = null;
    if (zkHostExpression == null) {
      zkHost = factory.getCollectionZkHost(collectionName);
      if (zkHost == null) {
        zkHost = factory.getDefaultZkHost();
      }
    } else if (zkHostExpression.getParameter() instanceof StreamExpressionValue) {
      zkHost = ((StreamExpressionValue) zkHostExpression.getParameter()).getValue();
    }

    if (zkHost == null) {
      throw new IOException(
          String.format(
              Locale.ROOT,
              "invalid expression %s - zkHost not found for collection '%s'",
              expression,
              collectionName));
    }

    init(collectionName, params, x, y, bucketSort, dimensionX, dimensionY, metric, zkHost);
  }

  /**
   * Builds a bucket from a named operand, or returns {@code null} when the operand is absent, is
   * not a plain value, or has a null/empty value.
   */
  private static Bucket bucketFrom(StreamExpressionNamedParameter bucketExpression) {
    if (bucketExpression == null
        || !(bucketExpression.getParameter() instanceof StreamExpressionValue)) {
      return null;
    }
    String key = ((StreamExpressionValue) bucketExpression.getParameter()).getValue();
    if (key == null || key.isEmpty()) {
      return null;
    }
    return new Bucket(key.trim());
  }

  /**
   * Turns a {@code "<field> <direction>"} sort spec into a comparator on {@code <field>}. Only the
   * first whitespace separated token is used; the order is always descending.
   */
  private static FieldComparator parseBucketSort(String bucketSortString) {
    String[] spec = bucketSortString.trim().split("\\s+");
    return new FieldComparator(spec[0].trim(), ComparatorOrder.DESCENDING);
  }

  /** Stores the stream configuration; {@code params} is defensively copied. */
  private void init(
      String collection,
      SolrParams params,
      Bucket x,
      Bucket y,
      FieldComparator bucketSort,
      int dimensionX,
      int dimensionY,
      Metric metric,
      String zkHost) {
    this.collection = collection;
    this.params = new ModifiableSolrParams(params);
    this.x = x;
    this.y = y;
    this.dimensionX = dimensionX;
    this.dimensionY = dimensionY;
    this.metric = metric;
    this.bucketSort = bucketSort;
    this.zkHost = zkHost;
  }

  /**
   * Describes this stream as a {@code STREAM_SOURCE} with a single {@code DATASTORE} child whose
   * expression lists the query parameters.
   */
  @Override
  public Explanation toExplanation(StreamFactory factory) throws IOException {
    StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString());
    explanation.setFunctionName(factory.getFunctionName(this.getClass()));
    explanation.setImplementingClass(this.getClass().getName());
    explanation.setExpressionType(ExpressionType.STREAM_SOURCE);
    explanation.setExpression(toExpression(factory).toString());

    StreamExplanation child = new StreamExplanation(getStreamNodeId() + "-datastore");
    child.setFunctionName(String.format(Locale.ROOT, "solr (%s)", collection));
    child.setImplementingClass("Solr/Lucene");
    child.setExpressionType(ExpressionType.DATASTORE);

    child.setExpression(
        params.stream()
            .map(
                e -> String.format(Locale.ROOT, "%s=%s", e.getKey(), Arrays.toString(e.getValue())))
            .collect(Collectors.joining(",")));
    explanation.addChild(child);

    return explanation;
  }

  /**
   * Serializes this stream back into an expression: collection, query parameters, {@code x},
   * {@code y}, {@code dimensions}, the metric and {@code zkHost}, in that order.
   */
  @Override
  public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
    StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));

    // collection; quoted when it contains a comma
    if (collection.indexOf(',') > -1) {
      expression.addParameter("\"" + collection + "\"");
    } else {
      expression.addParameter(collection);
    }

    // parameters for q,fl etc
    for (Entry<String, String[]> param : params.getMap().entrySet()) {
      for (String val : param.getValue()) {
        expression.addParameter(new StreamExpressionNamedParameter(param.getKey(), val));
      }
    }

    // buckets
    expression.addParameter(new StreamExpressionNamedParameter("x", x.toString()));
    expression.addParameter(new StreamExpressionNamedParameter("y", y.toString()));

    // dimensions
    expression.addParameter(
        new StreamExpressionNamedParameter("dimensions", dimensionX + "," + dimensionY));

    // metric
    expression.addParameter(metric.toExpression(factory));

    // zkHost
    expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));

    return expression;
  }

  /** Takes the Solr client cache from the given context. */
  @Override
  public void setStreamContext(StreamContext context) {
    clientCache = context.getSolrClientCache();
  }

  /** This stream is a source and has no child streams. */
  @Override
  public List<TupleStream> children() {
    return new ArrayList<>();
  }

  /**
   * Sends the facet request and buffers the resulting tuples.
   *
   * <p>If no client cache has been provided, one is created here and closed again by {@link
   * #close()}. Each call starts from an empty tuple list, so reopening the stream does not repeat
   * results from a previous call.
   *
   * @throws IOException wrapping any failure raised while sending the request or reading the
   *     response
   */
  @Override
  public void open() throws IOException {
    if (clientCache == null) {
      doCloseCache = true;
      clientCache = new SolrClientCache();
    } else {
      doCloseCache = false;
    }
    FieldComparator[] adjustedSorts = adjustSorts(x, y, bucketSort);

    String json = getJsonFacetString(x, y, metric, adjustedSorts, dimensionX, dimensionY);

    ModifiableSolrParams paramsLoc = new ModifiableSolrParams(params);
    paramsLoc.set("json.facet", json);
    paramsLoc.set("rows", "0");

    QueryRequest request = new QueryRequest(paramsLoc, SolrRequest.METHOD.POST);
    try {
      var cloudSolrClient = clientCache.getCloudSolrClient(zkHost);
      NamedList<Object> response = cloudSolrClient.request(request, collection);
      // Collect into a fresh list and only publish it once the response was read completely.
      List<Tuple> collected = new ArrayList<>();
      getTuples(response, collected, x, y, metric);
      this.tuples = collected;
      this.out = collected.iterator();
    } catch (Exception e) {
      throw new IOException(e);
    }
  }

  /**
   * Returns the next buffered tuple, or an EOF tuple whose {@code rows} field holds the number of
   * buffered tuples once they are exhausted.
   *
   * @throws IOException if the stream has not been opened
   */
  @Override
  public Tuple read() throws IOException {
    if (out == null) {
      throw new IOException("read() called before open()");
    }
    if (out.hasNext()) {
      return out.next();
    }
    Tuple tuple = Tuple.EOF();
    tuple.put("rows", tuples.size());
    return tuple;
  }

  /**
   * Closes the client cache if, and only if, it was created by {@link #open()}. The reference is
   * cleared afterwards so that a later {@link #open()} does not reuse a closed cache.
   */
  @Override
  public void close() throws IOException {
    if (doCloseCache && clientCache != null) {
      try {
        clientCache.close();
      } finally {
        clientCache = null;
        doCloseCache = false;
      }
    }
  }

  /** Returns the complete {@code json.facet} value, i.e. the nested facets wrapped in braces. */
  private String getJsonFacetString(
      Bucket x,
      Bucket y,
      Metric metric,
      FieldComparator[] adjustedSorts,
      int dimensionX,
      int dimensionY) {
    StringBuilder buf = new StringBuilder();
    appendJson(buf, x, y, metric, adjustedSorts, dimensionX, dimensionY);

    return "{" + buf + "}";
  }

  /**
   * Returns the sort for the x level (index 0) and the y level (index 1). A sort field containing
   * {@code "("} is treated as a metric and used for both levels; otherwise each level is sorted on
   * its own bucket field, keeping the order of {@code bucketSort}.
   */
  private FieldComparator[] adjustSorts(Bucket x, Bucket y, FieldComparator bucketSort) {
    if (bucketSort.getLeftFieldName().contains("(")) {
      return new FieldComparator[] {bucketSort, bucketSort};
    }
    return new FieldComparator[] {
      new FieldComparator(x.toString(), bucketSort.getOrder()),
      new FieldComparator(y.toString(), bucketSort.getOrder())
    };
  }

  /**
   * Appends the nested x/y terms facets to {@code buf}. For any metric whose identifier does not
   * start with {@code count(}, the metric is requested as {@code agg} on both levels.
   */
  private void appendJson(
      StringBuilder buf,
      Bucket x,
      Bucket y,
      Metric metric,
      FieldComparator[] adjustedSorts,
      int dimensionX,
      int dimensionY) {

    String identifier = metric.getIdentifier();
    boolean needsAgg = !identifier.startsWith("count(");

    String sortX = getFacetSort(adjustedSorts[0].getLeftFieldName(), metric);
    appendTermsFacetStart(buf, "x", x.toString(), dimensionX, sortX);
    if (needsAgg) {
      buf.append("\"agg\":\"").append(identifier).append('"');
      buf.append(",");
    }

    String sortY = getFacetSort(adjustedSorts[1].getLeftFieldName(), metric);
    appendTermsFacetStart(buf, "y", y.toString(), dimensionY, sortY);
    if (needsAgg) {
      buf.append("\"agg\":\"").append(identifier).append('"');
    }

    // closes: y sub-facets, y facet, x sub-facets, x facet
    buf.append("}}}}");
  }

  /**
   * Appends the opening part of one terms facet, up to and including the opening brace of its
   * {@code facet} (sub-facet) object. The sort direction is always {@code desc}.
   */
  private static void appendTermsFacetStart(
      StringBuilder buf, String name, String field, int limit, String sort) {
    buf.append('"').append(name).append('"');
    buf.append(":{");
    buf.append("\"type\":\"terms\"");
    buf.append(",\"field\":\"").append(field).append('"');
    buf.append(",\"limit\":").append(limit);
    buf.append(",\"overrequest\":1000");
    buf.append(",\"sort\":\"").append(sort).append(" desc\"");
    buf.append(",\"facet\":{");
  }

  /**
   * Maps a sort field to the name used in the facet request: {@code count} for a count metric,
   * {@code agg} when the field is this stream's metric, and {@code index} (sort by bucket value)
   * for anything else, so that the request never carries a literal {@code "null"} sort.
   */
  private String getFacetSort(String id, Metric metric) {
    if (id.startsWith("count(")) {
      return "count";
    } else if (id.equals(metric.getIdentifier())) {
      return "agg";
    }

    return "index";
  }

  /**
   * Converts the {@code facets} section of the response into tuples appended to {@code target}.
   * Nothing is added when the response has no {@code facets} section.
   */
  private void getTuples(
      NamedList<?> response, List<Tuple> target, Bucket x, Bucket y, Metric metric) {
    NamedList<?> facets = (NamedList<?>) response.get("facets");
    if (facets == null) {
      return;
    }
    fillTuples(target, new Tuple(), facets, x, y, metric);
  }

  /**
   * Walks the x buckets and their nested y buckets and appends one tuple per (x, y) pair.
   *
   * <p>Missing {@code x}, {@code y} or {@code buckets} entries are treated as "no buckets". The x
   * bucket list is read exactly once, so additional entries next to {@code buckets} cannot cause
   * duplicated tuples.
   */
  private void fillTuples(
      List<Tuple> tuples,
      Tuple currentTuple,
      NamedList<?> facets,
      Bucket x,
      Bucket y,
      Metric metric) {
    NamedList<?> allXBuckets = (NamedList<?>) facets.get("x");
    if (allXBuckets == null) {
      return;
    }
    List<?> xBuckets = (List<?>) allXBuckets.get("buckets");
    if (xBuckets == null) {
      return;
    }

    String bucketXName = x.toString();
    String bucketYName = y.toString();
    String identifier = metric.getIdentifier();
    boolean isCount = identifier.startsWith("count(");

    for (Object xEntry : xBuckets) {
      NamedList<?> bucket = (NamedList<?>) xEntry;
      Tuple tx = currentTuple.clone();
      tx.put(bucketXName, widenInteger(bucket.get("val")));

      NamedList<?> allYBuckets = (NamedList<?>) bucket.get("y");
      if (allYBuckets == null) {
        continue;
      }
      List<?> yBuckets = (List<?>) allYBuckets.get("buckets");
      if (yBuckets == null) {
        continue;
      }

      for (Object yEntry : yBuckets) {
        NamedList<?> bucketY = (NamedList<?>) yEntry;
        Tuple yt = tx.clone();
        yt.put(bucketYName, widenInteger(bucketY.get("val")));

        if (isCount) {
          long l = ((Number) bucketY.get("count")).longValue();
          yt.put("count(*)", l);
        } else {
          Number agg = (Number) bucketY.get("agg");
          // A bucket without an "agg" entry is emitted without the metric field instead of
          // failing the whole stream with a NullPointerException.
          if (agg != null) {
            if (!metric.outputLong) {
              yt.put(identifier, agg.doubleValue());
            } else if (agg instanceof Long || agg instanceof Integer) {
              yt.put(identifier, agg.longValue());
            } else {
              yt.put(identifier, Math.round(agg.doubleValue()));
            }
          }
        }
        tuples.add(yt);
      }
    }
  }

  /** Returns {@code Integer} bucket values as {@code Long}; every other value is unchanged. */
  private static Object widenInteger(Object val) {
    if (val instanceof Integer) {
      return ((Integer) val).longValue();
    }
    return val;
  }

  @Override
  public int getCost() {
    return 0;
  }

  /** Returns the descending comparator on the metric identifier. */
  @Override
  public StreamComparator getStreamSort() {
    return bucketSort;
  }

  public String getCollection() {
    return collection;
  }

  public Bucket getX() {
    return x;
  }

  public Bucket getY() {
    return y;
  }
}
